{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing Softmax Exploration\n",
    "\n",
    "Now, let's learn how to implement softmax exploration to find the best arm.\n",
    "\n",
    "First, let's import the necessary libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import gym_bandits\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating the bandit environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take the same two-armed bandit we saw in the epsilon-greedy section: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make(\"BanditTwoArmedHighLowFixed-v0\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's check the probability distribution of the arms:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.8, 0.2]\n"
     ]
    }
   ],
   "source": [
    "print(env.p_dist)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can observe that with arm 1 we win the game with 80% probability and with arm 2 we\n",
    "win the game with 20% probability. Here, the best arm is arm 1, as with arm 1 we win the\n",
    "game with 80% probability. Now, let's see how to find this best arm using the softmax exploration\n",
    "method."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize the variables\n",
    "\n",
    "First, let's initialize the variables:\n",
    "\n",
    "Initialize the `count` for storing the number of times an arm is pulled:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "count = np.zeros(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the `sum_rewards` for storing the sum of rewards of each arm:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "sum_rewards = np.zeros(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the `Q` for storing the average reward of each arm:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "Q = np.zeros(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define `num_rounds` - number of rounds (iterations):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "num_rounds = 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the softmax exploration function\n",
    "\n",
    "Now, let's define the softmax function with temperature `T` as:\n",
    "\n",
    "$$P_t(a) = \\frac{\\text{exp}(Q_t(a)/T)} {\\sum_{i=1}^n \\text{exp}(Q_t(i)/T)} $$"
   ]
  },
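  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for the role of the temperature, here is a small worked example. The values $Q(1) = 0.8$ and $Q(2) = 0.2$ are assumed purely for illustration; they are not estimates produced by this notebook. With a high temperature, $T = 50$:\n",
    "\n",
    "$$P(1) = \\frac{\\text{exp}(0.8/50)}{\\text{exp}(0.8/50) + \\text{exp}(0.2/50)} \\approx \\frac{1.0161}{1.0161 + 1.0040} \\approx 0.503$$\n",
    "\n",
    "so both arms are selected almost equally often (exploration). With a low temperature, $T = 0.1$:\n",
    "\n",
    "$$P(1) = \\frac{\\text{exp}(8)}{\\text{exp}(8) + \\text{exp}(2)} = \\frac{1}{1 + \\text{exp}(-6)} \\approx 0.998$$\n",
    "\n",
    "so the arm with the higher average reward is selected almost every time (exploitation)."
   ]
  },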
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(T):\n",
    "    \"\"\"Sample an arm from the softmax distribution over the global Q values.\"\"\"\n",
    "    \n",
    "    # compute the probability of each arm based on the above equation\n",
    "    denom = sum([np.exp(i/T) for i in Q])\n",
    "    probs = [np.exp(i/T)/denom for i in Q]\n",
    "    \n",
    "    # select the arm based on the computed probability distribution of arms\n",
    "    arm = np.random.choice(env.action_space.n, p=probs)\n",
    "    \n",
    "    return arm"
   ]
  },
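  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same probabilities can also be computed in a vectorized, numerically safer way. The sketch below is an illustration rather than part of the original walkthrough: `softmax_probs` is a hypothetical helper name, and it takes the value estimates as an argument instead of reading the global `Q`. Subtracting the largest scaled value before exponentiating leaves the result unchanged but avoids overflow when `T` becomes very small.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def softmax_probs(q_values, temperature):\n",
    "    # hypothetical helper (not in the original notebook):\n",
    "    # returns the softmax probabilities of the arms for a given temperature\n",
    "    scaled = np.asarray(q_values, dtype=float) / temperature\n",
    "    # shift so the largest exponent is 0; the ratio of exponentials is unchanged\n",
    "    scaled = scaled - scaled.max()\n",
    "    exp_scaled = np.exp(scaled)\n",
    "    return exp_scaled / exp_scaled.sum()\n",
    "\n",
    "# illustrative values only, not results from this notebook\n",
    "print(softmax_probs([0.8, 0.2], 50))\n",
    "print(softmax_probs([0.8, 0.2], 0.1))\n",
    "```\n",
    "\n",
    "An arm could then be drawn with `np.random.choice(len(probs), p=probs)`, in the same way as in the function above."
   ]
  },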
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start pulling the arm\n",
    "\n",
    "Now, let's play the game and try to find the best arm using the softmax exploration method."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's begin by setting the temperature `T` to a high number, say 50:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "T = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for i in range(num_rounds):\n",
    "    \n",
    "    #select the arm based on the softmax exploration method\n",
    "    arm = softmax(T)\n",
    "\n",
    "    #pull the arm and store the reward and next state information\n",
    "    next_state, reward, done, info = env.step(arm) \n",
    "\n",
    "    #increment the count of the arm by 1\n",
    "    count[arm] += 1\n",
    "    \n",
    "    #update the sum of rewards of the arm\n",
    "    sum_rewards[arm]+=reward\n",
    "\n",
    "    #update the average reward of the arm\n",
    "    Q[arm] = sum_rewards[arm]/count[arm]\n",
    "    \n",
    "    #reduce the temperature\n",
    "    T = T*0.99"
   ]
  },
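  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check on the schedule used above: multiplying by 0.99 in every round gives $T_k = 50 \\times 0.99^k$ after $k$ rounds, so at the end of the 100 rounds\n",
    "\n",
    "$$T_{100} = 50 \\times 0.99^{100} \\approx 50 \\times 0.366 \\approx 18.3$$\n",
    "\n",
    "Assuming the average rewards stay between 0 and 1, this temperature is still large relative to the `Q` values, so the arm probabilities remain close to uniform throughout this run. A smaller initial `T` or a faster decay would shift the balance toward exploitation sooner."
   ]
  },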
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After all the rounds, we look at the average reward obtained from each of the arms:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.84090909 0.17857143]\n"
     ]
    }
   ],
   "source": [
    "print(Q)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can select the optimal arm as the one that has the maximum average reward. Since arm 1 has a higher average reward than arm 2, our optimal arm is\n",
    "arm 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The optimal arm is arm 1\n"
     ]
    }
   ],
   "source": [
    "print('The optimal arm is arm {}'.format(np.argmax(Q)+1))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
